""" ShuffleNetV2 backbone."""

import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.ops.operations as P

def conv_block(in_channels,
               out_channels,
               kernel_size,
               stride,
               pad_mode,
               padding,
               num_features,
               group=1,
               has_bias=False,
               with_bn=True,
               with_relu=True):
    """Build a Conv2d layer, optionally followed by BatchNorm2d and ReLU.

    Args:
        in_channels (int): Number of input channels of the convolution.
        out_channels (int): Number of output channels of the convolution.
        kernel_size (int): Convolution kernel size.
        stride (int): Convolution stride.
        pad_mode (str): Padding mode passed to nn.Conv2d.
        padding (int): Padding amount passed to nn.Conv2d.
        num_features (int): Channel count for the BatchNorm2d layer.
        group (int): Convolution group count (group == in_channels gives
            a depthwise convolution). Default: 1.
        has_bias (bool): Whether the convolution has a bias term. Default: False.
        with_bn (bool): Append a BatchNorm2d layer. Default: True.
        with_relu (bool): Append a ReLU layer. Default: True.

    Returns:
        list, the layers in order: [Conv2d][, BatchNorm2d][, ReLU].
    """
    conv = nn.Conv2d(in_channels=in_channels,
                     out_channels=out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     pad_mode=pad_mode,
                     padding=padding,
                     group=group,
                     has_bias=has_bias)
    bn = [nn.BatchNorm2d(num_features=num_features, momentum=0.9)] if with_bn else []
    relu = [nn.ReLU()] if with_relu else []
    return [conv] + bn + relu


class ShuffleV2Block(nn.Cell):
    """
    ShuffleNetV2 block definition.

    For stride 1, `channel_shuffle` splits the input into two halves: one
    half bypasses the block, the other goes through the main branch, and
    the results are concatenated along the channel axis. For stride 2,
    the whole input feeds both a depthwise projection branch and the main
    branch, and their outputs are concatenated (spatial downsampling).

    Args:
        in_channels (int): Input channel.
        out_channels (int): Output channel.
        mid_channels (int): Middle (bottleneck) channel of the main branch.
        kernel_size (int): Depthwise convolution kernel size.
        stride (int): Stride for the depthwise convolution; expected 1 or 2.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ShuffleV2Block(256, 512, 512, kernel_size=3, stride=1)
    """

    def __init__(self, in_channels, out_channels, mid_channels, kernel_size, stride):
        super(ShuffleV2Block, self).__init__()
        self.stride = stride
        # assert stride in [1, 2]

        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        pad = kernel_size // 2
        self.pad = pad
        self.in_channels = in_channels
        # The bypass/projection path contributes in_channels to the concat,
        # so the main branch produces the remainder.
        main_out_channels = out_channels - in_channels

        # Main branch: pw conv -> dw conv -> pw-linear conv.
        # pw
        branch_main = conv_block(in_channels=in_channels, out_channels=mid_channels,
                                 kernel_size=1, stride=1, pad_mode='pad', padding=0,
                                 num_features=mid_channels, has_bias=False)
        # dw (group == channels makes it depthwise); no ReLU per ShuffleNetV2
        branch_main.extend(conv_block(in_channels=mid_channels, out_channels=mid_channels,
                                      kernel_size=kernel_size, stride=stride, pad_mode='pad',
                                      padding=pad, num_features=mid_channels, group=mid_channels,
                                      has_bias=False, with_relu=False))
        # pw-linear
        branch_main.extend(conv_block(in_channels=mid_channels, out_channels=main_out_channels,
                                      kernel_size=1, stride=1, pad_mode='pad', padding=0,
                                      num_features=main_out_channels, has_bias=False))
        self.branch_main = nn.SequentialCell(branch_main)

        if stride == 2:
            # Projection branch: dw conv -> pw conv (with BN + ReLU).
            branch_proj = conv_block(in_channels=in_channels, out_channels=in_channels,
                                     kernel_size=kernel_size, stride=stride, pad_mode='pad',
                                     padding=pad, num_features=in_channels, group=in_channels,
                                     has_bias=False, with_relu=False)
            # pw-linear
            # Bugfix: this conv was previously appended to `branch_main`
            # (after self.branch_main had already been built), so it was
            # silently dropped and the projection branch was left without
            # its pointwise conv. It belongs on `branch_proj`.
            branch_proj.extend(conv_block(in_channels=in_channels, out_channels=in_channels,
                                          kernel_size=1, stride=1, pad_mode='pad', padding=0,
                                          num_features=in_channels, has_bias=False))
            self.branch_proj = nn.SequentialCell(branch_proj)
        else:
            self.branch_proj = None

    def construct(self, x_old):
        """Run the block; returns None for unsupported strides."""
        if self.stride == 1:
            # Split channels in two: x_proj bypasses, x goes through the main branch.
            x_proj, x = self.channel_shuffle(x_old)
            return P.Concat(1)((x_proj, self.branch_main(x)))
        if self.stride == 2:
            x_proj = x_old
            x = x_old
            return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x)))
        return None

    def channel_shuffle(self, x):
        """Shuffle channels across the two groups, then split into two halves.

        Input is NCHW; returns a pair of tensors, each with num_channels // 2
        channels. Assumes num_channels is even.
        """
        batchsize, num_channels, height, width = P.Shape()(x)
        x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,))
        x = P.Transpose()(x, (1, 0, 2,))
        x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,))
        return x[0], x[1]


class ShuffleNetV2(nn.Cell):
    """
    ShuffleNetV2 architecture.

    Stacks a stem conv + maxpool, three stages of ShuffleV2Block, and a
    final 1x1 conv. Stage widths are selected by `model_size`.

    Args:
        model_size (string): Size of the model; one of '0.5x', '1.0x',
            '1.5x', '2.0x'. Default is '1.0x'.

    Raises:
        NotImplementedError: If `model_size` is not one of the supported keys.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ShuffleNetV2(model_size='1.5x')
    """

    def __init__(self, model_size='1.0x'):
        super(ShuffleNetV2, self).__init__()
        print('model size is ', model_size)

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        # Per-stage output channels indexed to line up with stage numbers;
        # index 0 is a placeholder (-1), index 1 is the stem width.
        out_channels_cfg = {
            '0.5x': [-1, 24, 48, 96, 192, 1024],
            '1.0x': [-1, 24, 116, 232, 464, 1024],
            '1.5x': [-1, 24, 176, 352, 704, 1024],
            '2.0x': [-1, 24, 244, 488, 976, 2048],
        }
        if model_size not in out_channels_cfg:
            # Same exception type as before, now with an actionable message.
            raise NotImplementedError('unsupported model_size: %r; choose one of %s'
                                      % (model_size, sorted(out_channels_cfg)))
        self.stage_out_channels = out_channels_cfg[model_size]

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(conv_block(in_channels=3,
                                                       out_channels=input_channel,
                                                       kernel_size=3, stride=2, pad_mode='pad', padding=1,
                                                       num_features=input_channel, has_bias=False))

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

        self.features = []
        for idxstage in range(len(self.stage_repeats)):
            num_repeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]

            for i in range(num_repeat):
                if i == 0:
                    # First block of each stage downsamples (stride 2).
                    self.features.append(ShuffleV2Block(input_channel, output_channel,
                                                        mid_channels=output_channel // 2, kernel_size=3, stride=2))
                else:
                    # Stride-1 blocks see only half the channels because of
                    # the channel split inside ShuffleV2Block.
                    self.features.append(ShuffleV2Block(input_channel // 2, output_channel,
                                                        mid_channels=output_channel // 2, kernel_size=3, stride=1))

                input_channel = output_channel

        self.features = nn.SequentialCell(self.features)

        self.conv_last = nn.SequentialCell(conv_block(in_channels=input_channel,
                                                      out_channels=self.stage_out_channels[-1],
                                                      kernel_size=1, stride=1, pad_mode='pad', padding=0,
                                                      num_features=self.stage_out_channels[-1], has_bias=False))

        self._initialize_weights()

    def construct(self, x):
        """Forward pass: stem -> maxpool -> stages -> final 1x1 conv."""
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.conv_last(x)
        return x

    def _initialize_weights(self):
        """
        Initialize conv and dense weights with Gaussian noise.

        The stem ('first' in the cell name) uses std 0.01; other convs use
        std 1.0 / fan_in (weight shape dim 1). Dense layers use std 0.01.

        Returns:
            None.
        """
        for name, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    m.weight.set_data(Tensor(np.random.normal(0, 0.01,
                                                              m.weight.data.shape).astype("float32")))
                else:
                    m.weight.set_data(Tensor(np.random.normal(0, 1.0 / m.weight.data.shape[1],
                                                              m.weight.data.shape).astype("float32")))
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))

    @property
    def get_features(self):
        # Kept as-is (public interface): exposes the stage SequentialCell.
        return self.features
